from torchmetrics import MaxMetric, Accuracy, ConfusionMatrix
import logging
import torch
import torch.nn.functional as F
import numpy as np
import torchvision
import torchvision.transforms as transforms
from collections import defaultdict
from src.pl_model.distillation import Distilltion
import wandb
import pandas as pd

log = logging.getLogger(__name__)


class MyDistillation(Distilltion):
    """Knowledge distillation with a per-class blend of cross-entropy and teacher KL.

    The per-class weight ``alpha`` multiplies the KL (teacher) term and ``1 - alpha``
    multiplies the cross-entropy (label) term. In baseline mode (``cfg.is_bl``) ``alpha``
    is a constant 0.5. Otherwise it is derived from the teacher's and learner's original
    per-class test accuracies, which are fetched from W&B confusion-matrix artifacts.
    """

    def __init__(self, cfg):
        super().__init__(cfg=cfg)

        # TODO: derive from cfg (the CIFAR-100 case needs 100); hard-coded to 10 for now.
        self.num_classes = 10

        self.learner_client = self.hparams.learner_client
        self.teacher_client = self.hparams.teacher_client

        # Summarise each client's training-data class distribution. The result is only logged.
        # Class keys in the distributions are looked up as strings ("0", "1", ...).
        total_num_samples_per_class = defaultdict(int)
        data_dists_vectorized = {}
        for client, info in self.hparams["clients"].items():
            data_dist = info["train_data_distribution"]
            data_dists_vectorized[client] = np.array(
                [data_dist.get(str(cls_idx)) or 0 for cls_idx in range(self.num_classes)]
            )
            for cls_idx, count in data_dist.items():
                total_num_samples_per_class[cls_idx] += count

        total_num_samples = sum(total_num_samples_per_class.values())

        log.info(data_dists_vectorized)
        log.info(total_num_samples_per_class)
        log.info(total_num_samples)

        # Must be set before get_alpha(), which reads it to locate the W&B artifacts.
        self.run_id = cfg.train_exp_id

        self.T = cfg.KL_temperature
        log.info(f"Using Temp = {self.T}")

        # Baseline: equal weight for the CE and KL losses; otherwise use per-class accuracy ratios.
        self.alpha = 0.5 if cfg.is_bl else self.get_alpha()

        log.info(f"alpha is: {self.alpha}")

    def on_fit_start(self):
        """Convert ``alpha`` to a float32 tensor on the module's device before training.

        float32 keeps the loss in the same dtype as the logits. Wrapping a float64 numpy
        array with ``torch.tensor`` would promote the whole loss to float64. ``as_tensor``
        also accepts an already-converted tensor, so repeated ``fit`` calls are safe.
        """
        # NOTE(review): this hook does not call super().on_fit_start(); confirm the parent
        # class does not rely on it.
        self.alpha = torch.as_tensor(self.alpha, dtype=torch.float32, device=self.device)

    def get_alpha(self):
        """Return the per-class KL weight from original per-class test accuracies.

        ``alpha[c] = teacher_acc[c] / (teacher_acc[c] + learner_acc[c])``. The KL term
        therefore gets more weight on classes where the teacher was relatively stronger.
        Classes where the ratio is undefined get ``alpha = 0`` (pure cross-entropy). That
        happens when both accuracies are 0, or when an accuracy is NaN because the class
        had no test samples.

        Returns:
            numpy array of shape ``(num_classes_in_table,)``.
        """
        teacher_acc = self.get_org_per_class_acc(self.teacher_client)
        learner_acc = self.get_org_per_class_acc(self.learner_client)
        # 0/0 is expected for some classes; nan_to_num maps the resulting NaN to 0.
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = teacher_acc / (teacher_acc + learner_acc)
        return np.nan_to_num(ratio)

    def training_step(self, batch, batch_idx):
        """Compute the alpha-weighted CE + temperature-scaled KL distillation loss.

        Both loss terms are kept element-wise with shape ``(B, C)``. This lets ``alpha``
        (a scalar or a per-class vector) weight each class before summing and dividing by
        the batch size.

        Returns:
            dict with ``loss`` (CE + KL), ``preds`` (argmax of learner logits) and
            ``targets`` (the labels ``y``).
        """
        x, y = batch
        logits = self(x)
        with torch.no_grad():
            teacher_logits = self.teacher_model(x)

        divergence = F.kl_div(
            F.log_softmax(logits / self.T, dim=1),
            F.softmax(teacher_logits / self.T, dim=1),
            reduction='none'
        )

        # KL against a one-hot target equals cross-entropy (0 * log 0 contributes 0).
        onehot_y = F.one_hot(y, self.num_classes).to(torch.float)
        ce = F.kl_div(
            F.log_softmax(logits, dim=1),
            onehot_y,
            reduction='none'
        )

        batch_size = logits.size(0)
        ce = (ce * (1 - self.alpha)).sum() / batch_size
        divergence = (divergence * self.alpha).sum() / batch_size

        # Standard T^2 rescaling keeps KL gradient magnitudes comparable across temperatures.
        if not self.hparams.not_multiply_T:
            divergence = divergence * self.T * self.T

        self.log(
            f"{self.exp_name}/learner-kl_loss",
            divergence, on_step=True, on_epoch=False, prog_bar=True
        )

        self.log(
            f"{self.exp_name}/learner-ce_loss",
            ce, on_step=True, on_epoch=False, prog_bar=True
        )

        preds = torch.argmax(logits, dim=1)
        acc = self.train_acc(preds, y)
        self.log(
            f"{self.exp_name}/train_acc",
            acc, on_step=False, on_epoch=True, prog_bar=False
        )

        return {"loss": ce + divergence, "preds": preds, "targets": y}

    def get_org_per_class_acc(self, client_idx):
        """Fetch a client's per-class test accuracy from the original training run on W&B.

        Loads the ``confusion_matrix_table`` artifact (version ``v0``) that run
        ``self.run_id`` (``"entity/project/run_id"``) logged for ``client_idx``. Each row
        (actual class) is normalised by its total, and the diagonal is returned. A class
        with no test samples yields NaN.
        """
        api = wandb.Api()
        train_run_entity, train_run_project, train_run_idx = self.run_id.split('/')
        split = "test"
        version = "v0"

        art = api.artifact(
            f'{train_run_entity}/{train_run_project}/run-{train_run_idx}-{split}_client{client_idx}confusion_matrix_table:{version}')
        table = art.get(f"{split}_client-{client_idx}/confusion_matrix_table")

        confusion_matrix = self.table_to_dataframe(table)
        # Row-wise normalisation (rows are actual classes): diagonal entries become per-class accuracy.
        confusion_matrix = confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)
        return np.diag(confusion_matrix)

    def table_to_dataframe(self, table):
        """Convert a W&B confusion-matrix table into a square pandas DataFrame.

        Each row of ``table.data`` is ``[actual_label, predicted_label, count, ...]``.
        Labels are shortened to the text after their last ``"_"`` (e.g. ``"Class_3"`` ->
        ``"3"``). Rows of the result are actual classes and columns are predicted classes.
        Both use the order in which actual labels first appear in the table.

        Every predicted label must also occur as an actual label; otherwise a KeyError is
        raised. Pairs missing from the table are left at 0.
        """
        # Ordered, de-duplicated actual labels.
        labels = list(dict.fromkeys(row[0].split("_")[-1] for row in table.data))
        position = {label: i for i, label in enumerate(labels)}

        counts = {label: [0] * len(labels) for label in labels}
        for row in table.data:
            actual = row[0].split("_")[-1]
            predicted = row[1].split("_")[-1]
            counts[predicted][position[actual]] = row[2]
        return pd.DataFrame(counts, index=labels)
